import pandas as pd
import numpy as np
from transformers import pipeline, logging
from datasets import load_dataset
import time
import torch
import warnings
import accelerate
# Keep benchmark output readable: silence library user/future warnings,
# opt in to pandas' future (non-silent) downcasting behavior, disable the
# chained-assignment warning, and reduce transformers logging to errors only.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None
logging.set_verbosity_error()

# Hugging Face access token used later to load the Llama model. Read it through
# a context manager so the file handle is closed, and strip surrounding
# whitespace (e.g. a trailing newline) so it does not become part of the token.
with open('./llama_key.txt', 'r', encoding='utf-8') as key_file:
    llama_key = key_file.read().strip()

# Benchmark corpus: a fixed (seeded) random sample of the Pol_NLI test split.
ds = load_dataset('mlburnham/Pol_NLI')
test = ds['test'].to_pandas()
ndocs = 5000
test = test.sample(ndocs, random_state = 1)
timings = []  # one summary dict appended per benchmarked model
docs = list(test['premise'])

def time_it(pipe, docs, n=1, model_name=None):
    """
    Benchmark the wall-clock time a zero-shot classification pipeline takes on ``docs``.

    Each run classifies every document against the single hypothesis
    'This text is about politics.' (used verbatim via ``hypothesis_template='{}'``).

    Args:
        pipe (callable): The model pipeline to benchmark. Must expose
            ``pipe.device.type``; the reported model name is read from
            ``pipe.model.name_or_path`` when available.
        docs (list): Documents passed to the pipe on every run.
        n (int): Number of timed runs; must be >= 1. Defaults to 1.
        model_name (str, optional): Name to report. Defaults to the
            pipeline's ``model.name_or_path``, falling back to the
            module-level ``model`` string. Only the part after the last
            '/' is reported.

    Returns:
        dict: Keys 'Model', 'Hardware', 'Time' (mean seconds per run),
        'Time_SE', 'DPS' (documents per second, ``len(docs)`` / mean time)
        and 'DPS_SE'. The standard errors are None when n == 1.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    # Derive the throughput denominator from the documents actually processed,
    # not from a module-level constant that may disagree with `docs`.
    n_docs = len(docs)
    times = []

    for i in range(n):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so it is the right clock for measuring elapsed intervals.
        start_time = time.perf_counter()
        pipe(docs, 'This text is about politics.', hypothesis_template='{}')
        elapsed_time = time.perf_counter() - start_time
        times.append(elapsed_time)

        print(f"Run {i + 1}/{n} - Elapsed time: {elapsed_time:.2f} seconds")

    # Average time and documents-per-second over all runs.
    avg_time = np.mean(times)
    dps = n_docs / avg_time

    # Standard errors (sample std / sqrt(n)) need at least two runs.
    if n > 1:
        time_se = np.std(times, ddof=1) / np.sqrt(n)
        dps_se = np.std([n_docs / t for t in times], ddof=1) / np.sqrt(n)
    else:
        time_se = None
        dps_se = None

    print(f"Average elapsed time: {avg_time:.2f} seconds")
    print(f"Average DPS: {dps}")

    if n > 1:
        print(f"Standard error (Time): {time_se:.4f} seconds")
        print(f"Standard error (DPS): {dps_se:.4f}")

    # Only release cached MPS memory when the pipeline actually runs on MPS;
    # on other devices the MPS backend may not be available at all.
    if pipe.device.type == "mps":
        torch.mps.empty_cache()

    if model_name is None:
        # Prefer the name the pipeline's model was loaded from; fall back to
        # the module-level `model` string the script sets before each run.
        model_name = getattr(getattr(pipe, "model", None), "name_or_path", None) or model

    return {
        'Model': model_name.split('/')[-1],
        'Hardware': pipe.device.type,
        'Time': avg_time,
        'Time_SE': time_se,
        'DPS': dps,
        'DPS_SE': dps_se,
    }
    
# ----------------------------------------------------------------
# Political DEBATE, DeBERTa-base backbone (zero-shot NLI pipeline)
# ----------------------------------------------------------------
model = "mlburnham/Political_DEBATE_DeBERTa_base_v1.1"
pipe = pipeline(
    "zero-shot-classification",
    model=model,
    device=torch.device("mps"),
    batch_size=8,
    torch_dtype=torch.bfloat16,
)

# Untimed warm-up pass (result discarded) so one-off first-call overhead,
# such as kernel compilation, is kept out of the benchmark numbers.
time_it(pipe, docs, n=1)

# Measured benchmark: 10 repetitions; the summary dict goes into `timings`.
res = time_it(pipe, docs, n=10)
timings.append(res)

# -----------------------------------------------------------------
# Political DEBATE, DeBERTa-large backbone (zero-shot NLI pipeline)
# -----------------------------------------------------------------
model = "mlburnham/Political_DEBATE_DeBERTa_large_v1.1"
pipe = pipeline(
    "zero-shot-classification",
    model=model,
    device=torch.device("mps"),
    batch_size=8,
    torch_dtype=torch.bfloat16,
)

# Untimed warm-up pass (result discarded) so one-off first-call overhead,
# such as kernel compilation, is kept out of the benchmark numbers.
time_it(pipe, docs, n=1)

# Measured benchmark: 10 repetitions; the summary dict goes into `timings`.
res = time_it(pipe, docs, n=10)
timings.append(res)

#############
## Llama 3.1
#############
import gc

# Drop the reference to the previous (DeBERTa-large) pipeline before loading
# the 8B model. While `pipe` still points at it, empty_cache() cannot return
# that model's memory.
pipe = None
gc.collect()
torch.mps.empty_cache()

model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipe = pipeline("text-generation", model=model, model_kwargs={"torch_dtype": torch.bfloat16}, device_map='mps', batch_size = 1,
token = llama_key)

# Give the tokenizer a pad token id by reusing the (first) EOS id. The model
# config may store eos_token_id either as a list of ids or as a single int.
eos_token_id = pipe.model.config.eos_token_id
if isinstance(eos_token_id, (list, tuple)):
    pipe.tokenizer.pad_token_id = eos_token_id[0]
else:
    pipe.tokenizer.pad_token_id = eos_token_id

user_message = """You are a classifier that can only respond with 1 or 0. I'm going to show you a short text sample and I want you to determine if this text is about politics. Here is the text:
{doc}

If it is true that this text is about politics, return 1. If it is not true that this text is about politics, return 0.
Do not explain your answer, and only return 1 or 0.
"""

# One single-turn chat per document, rendered to a prompt string that ends
# with the assistant generation header.
messages = [{"role": "user", "content": user_message.format(doc = doc)} for doc in test['premise']]
prompt = [pipe.tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True) for message in messages]

# Greedy decoding of at most two new tokens (the 0/1 answer). `temperature`
# is omitted because it only applies when sampling (do_sample=True).
generate_kwargs = dict(max_new_tokens=2, do_sample=False, return_full_text=False,
                       pad_token_id=pipe.tokenizer.pad_token_id)

# Short untimed warm-up, mirroring the untimed warm-up run used for the
# DeBERTa models, so one-off first-call overhead is not counted.
pipe(prompt[:8], **generate_kwargs)
torch.mps.empty_cache()

# perf_counter is monotonic and high-resolution, so it is the right clock for
# measuring elapsed time.
start_time = time.perf_counter()
results = pipe(prompt, **generate_kwargs)
elapsed_time = time.perf_counter() - start_time

n_prompts = len(prompt)
dps = n_prompts / elapsed_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")
print(f"DPS: {dps}")
torch.mps.empty_cache()

# Single timed run, so there are no standard errors. The keys match the
# time_it() summaries so the final table has uniform columns.
timings.append({
                'Model': model.split('/')[-1],
                'Hardware': 'mps',
                'Time': elapsed_time,
                'Time_SE': None,
                'DPS': dps,
                'DPS_SE': None,
            })

import os

# Collect every model's benchmark summary into one table. The original note
# "31 minutes" presumably records the observed runtime of the full benchmark;
# confirm before relying on it.
mps = pd.DataFrame(timings)
# DataFrame.to_csv does not create missing parent directories.
os.makedirs('./data', exist_ok=True)
mps.to_csv('./data/mps_timings.csv', index = False)